{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "# 加载 mnist_inference 中定义的常理和前向传播的函数。\n",
    "import mnist_inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#配置神经网络的参数\n",
    "BATCH_SIZE = 100\n",
    "LEARNING_RATE_BASE = 0.8\n",
    "LEARNING_RATE_DECAY = 0.99\n",
    "REGULARAZTION_RATE = 0.0001\n",
    "TRAINING_STEPS = 30000\n",
    "MOVING_AVERAGE_DECAY = 0.99\n",
    "#模型保存的路径和文件名。\n",
    "MODEL_SAVE_PATH = \"./model/\"\n",
    "MODEL_NAME = \"model.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(mnist):\n",
    "    #定义输入输出placeholder。\n",
    "    x = tf.placeholder(tf.float32,[None,mnist_inference.INPUT_NODE],name='x-input')\n",
    "    y_ = tf.placeholder(tf.float32,[None,mnist_inference.OUTPUT_NODE],name='y-input')\n",
    "    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)\n",
    "    #直接使用mnist_inference.py 中定义的前向传播过程\n",
    "    y = mnist_inference.inference(x,regularizer)\n",
    "    global_step = tf.Variable(0,trainable=False)\n",
    "    # 定义损失函数、学习率、滑动平均操作以及训练过程。\n",
    "    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)\n",
    "    variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(y,tf.argmax(y_,1))\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))\n",
    "    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE,LEARNING_RATE_DECAY)\n",
    "    train_step = tr.train.GradientDescentOptimizer(learning_rate.minimize(loss,global_step= global_step))\n",
    "    with tf.control_dependencies([train_step,variables_averages_op]):\n",
    "        train_op = tf.no_op(name='train')\n",
    "    #初始化TensorFlow 持久化类\n",
    "    saver = tf.train.Saver()\n",
    "    with tf.Session() as sess:\n",
    "        tf.global_variables_initializer().run()\n",
    "        #在训练过程中不再测试模型在验证数据上的表现，验证和测试的过程将会有一个独立的程序来完成\n",
    "        for i in range(TRAINING_STEPS):\n",
    "            xs,ys = mnist.train.next_batch(BATCH_SIZE)\n",
    "            _,loss_value,step = sess.run([train_op,loss,global_step],feed_dict={x:xs,y:ys})\n",
    "            #每1000轮保存一次模型\n",
    "            if i % 1000 == 0:\n",
    "                #输出当前的训练情况。这里只输出了模型在当前训练batch上的损失函数大小。\n",
    "                #通过损失函数的大小可以大概了解训练的情况。在验证数集上的正确率信息会有一个单独的程序来生成。\n",
    "                print(\"After %d training step(s),loss on training batch is %g.\" % (step,loss_value))\n",
    "                #保存当前的模型。注意这里给出了global_step 参数，这样可以让每个被保存模型的文件名末尾加上训练的轮数。\n",
    "                #比如\"model.ckpt-1000\" 表示训练1000轮之后得到的模型\n",
    "                saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
      "Extracting ./tmp/data\\train-images-idx3-ubyte.gz\n",
      "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
      "Extracting ./tmp/data\\train-labels-idx1-ubyte.gz\n",
      "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
      "Extracting ./tmp/data\\t10k-images-idx3-ubyte.gz\n",
      "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
      "Extracting ./tmp/data\\t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'reqularizer' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-12-df61574b3c09>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m     \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m     \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mD:\\tool\\Miniconda3\\lib\\site-packages\\tensorflow\\python\\platform\\app.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(main, argv)\u001b[0m\n\u001b[0;32m     46\u001b[0m   \u001b[1;31m# Call the main function, passing through any arguments\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     47\u001b[0m   \u001b[1;31m# to the final program.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 48\u001b[1;33m   \u001b[0m_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margv\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mflags_passthrough\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     49\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-12-df61574b3c09>\u001b[0m in \u001b[0;36mmain\u001b[1;34m(argv)\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margv\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m     \u001b[0mmnist\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_data\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_data_sets\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"./tmp/data\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mone_hot\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m     \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      4\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-10-68b76b7ba780>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(mnist)\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[0mregularizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontrib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0ml2_regularizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mREGULARAZTION_RATE\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m     \u001b[1;31m#直接使用mnist_inference.py 中定义的前向传播过程\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m     \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmnist_inference\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minference\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mregularizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      8\u001b[0m     \u001b[0mglobal_step\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mVariable\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtrainable\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m     \u001b[1;31m# 定义损失函数、学习率、滑动平均操作以及训练过程。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\DataMiningDemo\\tensorFlow\\mnistDemo\\mnist_inference.py\u001b[0m in \u001b[0;36minference\u001b[1;34m(input_tensor, regularizer)\u001b[0m\n\u001b[0;32m     44\u001b[0m         \u001b[1;31m#这里通过tf.get_variable 或 tf.Varibale 没有本质区别，因为在训练或测试中没有在同一个程序中多次调用这个函数。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     45\u001b[0m         \u001b[1;31m#如果在同一个程序中多次调用，在第一次调用之后需要将resuse 参数设置为True。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 46\u001b[1;33m         \u001b[0mweights\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_weight_variable\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mINPUT_NODE\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mLAYER1_NODE\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mreqularizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     47\u001b[0m         \u001b[0mbiases\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_variable\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"biases\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mLAYER1_NODE\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0minitializer\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconstant_initializer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0.0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     48\u001b[0m         \u001b[0mlayer1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mweights\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mbiases\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'reqularizer' is not defined"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    mnist = input_data.read_data_sets(\"./tmp/data\",one_hot=True)\n",
    "    train(mnist)\n",
    "if __name__ == '__main__':\n",
    "    tf.app.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
